from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import itertools
import data_proc3 as dp3
import numpy as np
from PIL import Image
import math
import clustering
import scipy
import ripser


class Net(nn.Module):
    """Small convolutional classifier that emits two-class log-probabilities.

    Two 3x3 convolutions, a 2x2 max-pool and two fully connected layers.
    ``fc1`` expects 9216 = 64 * 12 * 12 input features, which is what the
    convolution/pooling stack produces for a single-channel 28x28 image.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout for the 4-D feature maps coming out of the pool.
        self.dropout1 = nn.Dropout2d(0.25)
        # This layer runs on flattened (batch, features) activations, so plain
        # element-wise dropout is the right module. Dropout2d on a 2-D input is
        # deprecated in recent PyTorch releases. Dropout layers hold no
        # parameters, so existing checkpoints still load.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        """Return log-softmax scores of shape (batch, 2) for the batch ``x``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension, flatten everything else for the dense layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one pass of NLL-loss training over ``train_loader``.

    ``args`` must provide ``log_interval`` (progress is printed every that many
    batches) and ``dry_run`` (when true, the pass stops right after the first
    progress line). ``epoch`` is only used in the progress message.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Only batches that hit the logging interval report progress; the
        # dry-run exit is tied to that same condition.
        if step % args.log_interval != 0:
            continue

        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), len(train_loader.dataset),
            100. * step / len(train_loader), batch_loss.item()))
        if args.dry_run:
            break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    percentage_correct = 100. * correct / len(test_loader.dataset)
    print('End of epoch: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset), percentage_correct))

    return test_loss, percentage_correct



def trainer(idx1, idx2, seed, mode='train'):
    """Train a two-class KMNIST CNN, or evaluate a saved one on every 0-vs-k task.

    Args:
        idx1, idx2: the two class indices identifying the model (and, in train
            mode, the dataset built with ``dp3.two_KMNIST``).
        seed: seed passed to ``torch.manual_seed``; also part of the checkpoint
            file name.
        mode: ``'train'`` trains and saves a checkpoint; any other value loads
            the checkpoint and evaluates it.

    Returns:
        ``None`` in train mode. Otherwise ``(loss, pc)``: two lists holding the
        average loss and percentage accuracy for each of the nine 0-vs-k tasks.

    Hyper-parameters come from ``argparse``, i.e. from ``sys.argv`` of the
    running process, on every call.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=5, metavar='N',
                        help='number of epochs to train (default: 5)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    # NOTE(review): --save-model is parsed but never read; the checkpoint is
    # always written in train mode.
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        # NOTE(review): shuffling is only enabled on the CUDA path; CPU runs
        # iterate the dataset in order. Left as-is to keep results reproducible.
        kwargs.update({'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True})

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    if mode == 'train':
        dataset1 = dp3.two_KMNIST(idx1, idx2, transform)
        train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
        # The per-epoch evaluation below runs on the same dataset as training.
        test_loader = torch.utils.data.DataLoader(dataset1, **kwargs)

        model = Net().to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
        for epoch in range(1, args.epochs + 1):
            train(args, model, device, train_loader, optimizer, epoch)
            test(model, device, test_loader)
            scheduler.step()

        checkpoint = 'models/KMNIST/mnist_cnn_' + str(idx1) + str(idx2) + '_' + str(seed) + '.pt'
        torch.save(model.state_dict(), checkpoint)

    else:
        combin = [[0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6], [0, 7], [0, 8], [0, 9]]

        filename = 'models/KMNIST/mnist_cnn_' + str(idx1) + str(idx2) + '_' + str(seed) + '.pt'
        # test() moves every batch to `device`, so the model must live there
        # too; map_location lets a checkpoint saved on GPU load on a CPU-only
        # machine (and vice versa).
        model = Net().to(device)
        model.load_state_dict(torch.load(filename, map_location=device))

        loss = []
        pc = []

        for task_idx1, task_idx2 in combin:
            dataset = dp3.two_KMNIST(task_idx1, task_idx2, transform)

            test_loader = torch.utils.data.DataLoader(dataset, **kwargs)
            test_loss, percentage_correct = test(model, device, test_loader)

            loss.append(test_loss)
            pc.append(percentage_correct)

        return loss, pc


def train_all():
    """Train and save one model per 0-vs-k class pair (k = 1..9) for seed 1."""
    pairs = [[0, other] for other in range(1, 10)]

    for seed in [1]:
        for pair in pairs:
            print("\nSeed: {}, digits: {}".format(seed, pair))
            first, second = pair
            trainer(first, second, seed)


def compute_taskwise_loss():
    """Evaluate every saved 0-vs-k model on every 0-vs-k task and store the results.

    For seed 0, each model is run through ``trainer(..., mode='test')``; the
    per-task losses and accuracies are printed and written to ``results/`` as
    arrays with one row per model.
    """
    pairs = [[0, other] for other in range(1, 10)]

    for seed in [0]:
        tot_loss = []
        tot_pc = []
        for first, second in pairs:
            task_losses, task_accuracies = trainer(first, second, seed, mode='test')
            tot_loss.append(np.array(task_losses))
            tot_pc.append(np.array(task_accuracies))

        print(tot_loss)
        print(tot_pc)

        suffix = '_model_0vAll_seed_' + str(seed) + '.npy'
        np.save('results/kmnist_loss' + suffix, np.array(tot_loss))
        np.save('results/kmnist_pc' + suffix, np.array(tot_pc))


def compute_decision_boundaries():
    """Save, per model, each raw input image together with the model's output.

    For seeds 1 and 2 and every 0-vs-k pair, the saved checkpoint is loaded and
    applied to every sample returned by ``dp3.load_data``. Each row of the
    stored array is the flattened, untransformed image followed by
    ``exp(log_softmax)[0]``, i.e. the probability the model gives to output
    index 0.
    """
    combin = [[0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6], [0, 7], [0, 8], [0, 9]]

    # The preprocessing is the same for every model, so build it once.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    for seed in [1, 2]:
        for indices in combin:

            print("Computing DB for Seed: " + str(seed) + ", digits: " + str(indices))

            idx1 = indices[0]
            idx2 = indices[1]

            filename = 'models/KMNIST/mnist_cnn_' + str(idx1) + str(idx2) + '_' + str(seed) + '.pt'
            model = Net()
            # Inference runs on CPU here, so remap checkpoints saved from a GPU.
            model.load_state_dict(torch.load(filename, map_location='cpu'))
            # A freshly built module is in training mode; without eval() the
            # dropout layers stay active and the recorded outputs are random.
            model.eval()

            data, labels = dp3.load_data(idx1, idx2, 'No')

            db = []

            # No gradients are needed for inference.
            with torch.no_grad():
                for i in range(len(data)):
                    img = Image.fromarray(data[i].numpy(), mode='L')
                    img = transform(img)

                    # img[None, ...] adds the batch dimension the model expects.
                    log_probs = model(img[None, ...])
                    result = math.exp(log_probs[0][0].item())
                    db.append(np.append(data[i].flatten().numpy(), result))

            decision_boundary = np.array(db)

            filename = 'boundaries/kmnist/notransform_mnist_cnn_' + str(idx1) + str(idx2) + '_' + str(seed) + '.npy'
            np.save(filename, decision_boundary)


def cluster(D, k=10, max_iter=10, verbose=True):
    """Run ``clustering.pd_fuzzy`` on the diagrams ``D`` and return its two results.

    ``k``, ``max_iter`` and ``verbose`` are forwarded unchanged. The returned
    pair is ``(r, M)`` as produced by ``pd_fuzzy`` (presumably the fuzzy
    membership values and the cluster centres; confirm against ``clustering``).
    """
    print("Clustering...")
    memberships, centres = clustering.pd_fuzzy(
        D, k, verbose=verbose, max_iter=max_iter)
    return memberships, centres


def most_pers(dgm, n):
    """Return the ``n`` most persistent points of diagram ``dgm`` as a list.

    Points are ordered by ``point[0] - point[1]`` ascending, which puts the
    largest ``point[1] - point[0]`` first; ties keep their input order.
    """
    def negative_persistence(point):
        first, second = point[0], point[1]
        return first - second

    ranked = sorted(dgm, key=negative_persistence)
    return ranked[:n]


def diagram(db, num_filt, maxdim):
    """Compute a persistence diagram from a subsample of a stored boundary array.

    ``db`` holds one row per sample: all columns but the last are divided by
    255, the last column is rounded to the nearest integer. Only the leading
    10% of the rows are considered, and of those only the first half is passed
    to ``ripser``. The entry at index 1 of the returned ``'dgms'`` list is
    handed back, so ``maxdim`` must be at least 1. ``num_filt`` is accepted but
    not used.
    """
    features, outputs = db[:, :-1], db[:, -1]

    # First cut: the leading tenth of the rows.
    head = math.floor(0.1 * len(features))
    features = features[:head] / 255
    outputs = np.rint(outputs[:head])

    # Second cut: half of what is left, taken after the columns are rejoined.
    kept = math.floor(0.5 * len(features))
    cloud = np.concatenate((features, outputs.reshape(-1, 1)), axis=1)[:kept]

    diagrams = ripser.ripser(cloud, maxdim=maxdim)['dgms']
    return diagrams[1]


def compute_diagrams(num_filt=100, maxdim=1, num_points=10, seed=0):
    """Build truncated persistence diagrams for every 0-vs-k pair.

    For each pair the list gains two entries, in this order: the diagram of
    the model boundary saved for ``seed``, then the diagram of the true
    boundary. Each diagram is cut to its ``num_points`` most persistent points.
    """
    pairs = [[0, other] for other in range(1, 10)]

    D = []

    for pair in pairs:
        tag = str(pair[0]) + str(pair[1])

        # (progress message, boundary file): model boundary first, true second.
        sources = (
            ("Computing diagram for Seed: " + str(seed) + ", digits: " + str(pair),
             'boundaries/kmnist/notransform_mnist_cnn_' + tag + '_' + str(seed) + '.npy'),
            ("Computing diagram for true boundary, digits: " + str(pair),
             'boundaries/kmnist/true_boundary_' + tag + '.npy'),
        )

        for message, path in sources:
            print(message)
            boundary = np.load(path)
            full_diagram = diagram(boundary, num_filt, maxdim)
            D.append(most_pers(full_diagram, num_points))

    return D


def _accuracies_by_assignment(r, pc, task_to_clust, count, weakest=False):
    """Gather, per task, the accuracies of the models ranked by cluster membership.

    For every seed and task, the nine models are ordered by their membership
    value (``r[seed][2 * model][centre]``) in the task's assigned cluster. The
    ``count`` strongest (or, with ``weakest=True``, the ``count`` weakest)
    models are selected and their accuracy on that task is recorded.

    Returns an array of shape ``(num_tasks, num_seeds * count)``.
    """
    num_seeds = len(task_to_clust)
    num_tasks = len(task_to_clust[0])
    selected = np.zeros((num_tasks, num_seeds * count))

    for seed in range(num_seeds):
        for j in range(num_tasks):
            centre = task_to_clust[seed][j]
            cand = np.array([r[seed][2 * i][centre] for i in range(num_tasks)])
            order = cand.argsort()
            # Ascending order: the tail holds the strongest members.
            chosen = order[:count] if weakest else order[-count:][::-1]
            for slot, model_idx in enumerate(chosen):
                selected[j][seed * count + slot] = pc[seed][model_idx][j]

    return selected


def mean_performance(r=None, do_print=False):
    """Print how models chosen via cluster membership compare with the average model.

    Args:
        r: per-seed membership arrays (three seeds are read). Only the even
            rows ``r[seed][2 * i]`` are used, one per model ``i``; each row is
            indexed by cluster.
        do_print: when true, also print per-task means and standard errors.

    The accuracies are read from
    ``results/kmnist_pc_model_0vAll_seed_<seed>.npy`` (rows: models, columns:
    tasks). For each task the mean accuracy over all models is compared with
    the mean over the top-1, top-2 and top-3 models of the task's cluster.
    Nothing is returned; the summary is printed.

    Raises:
        ValueError: if ``r`` is ``None`` (there is no stored default to load).
    """
    if r is None:
        # Previously this fell through to an opaque IndexError on an empty list.
        raise ValueError("mean_performance requires the membership arrays `r`; "
                         "no stored default is available")

    # Import the submodule explicitly: a bare `import scipy` does not load
    # `scipy.stats` on every SciPy version.
    from scipy import stats

    num_seeds = 3
    num_tasks = 9  # also the number of models: one per 0-vs-k pair

    pc = []
    task_to_clust = []
    for seed in range(num_seeds):
        pc.append(np.load('results/kmnist_pc_model_0vAll_seed_' + str(seed) + '.npy'))
        # task_to_clust[seed][i]: cluster with the largest membership in row 2*i.
        task_to_clust.append(
            [r[seed][2 * i].argsort()[::-1][0] for i in range(num_tasks)])

    # ALL: every model of every seed, grouped by task.
    results = np.zeros((num_tasks, num_seeds * num_tasks))
    for j in range(num_tasks):
        for seed in range(num_seeds):
            for i in range(num_tasks):
                results[j][seed * num_tasks + i] = pc[seed][i][j]

    results_top3 = _accuracies_by_assignment(r, pc, task_to_clust, 3)
    results_top2 = _accuracies_by_assignment(r, pc, task_to_clust, 2)
    results_top1 = _accuracies_by_assignment(r, pc, task_to_clust, 1)
    results_bot3 = _accuracies_by_assignment(r, pc, task_to_clust, 3, weakest=True)

    diff_top1_avg, diff_top2_avg, diff_top3_avg = [], [], []
    diff_top1_avg_pc, diff_top2_avg_pc, diff_top3_avg_pc = [], [], []

    for i in range(num_tasks):
        mean_bot3 = np.mean(results_bot3[i])
        mean = np.mean(results[i])
        mean_top3 = np.mean(results_top3[i])
        mean_top2 = np.mean(results_top2[i])
        mean_top1 = np.mean(results_top1[i])

        if do_print:
            task = '0v' + str(i + 1)
            print('Task bot3: ' + task + ', mean = ' + str(mean_bot3)
                  + ' +- ' + str(stats.sem(results_bot3[i])))
            print('Task all : ' + task + ', mean = ' + str(mean)
                  + ' +- ' + str(stats.sem(results[i])))
            print('Task top3: ' + task + ', mean = ' + str(mean_top3)
                  + ' +- ' + str(stats.sem(results_top3[i])))

            print(mean_bot3)
            print(mean)
            print(mean_top3)
            print(mean_top2)
            print(mean_top1)

        diff_top1_avg.append(mean_top1 - mean)
        diff_top2_avg.append(mean_top2 - mean)
        diff_top3_avg.append(mean_top3 - mean)

        # Same gaps expressed relative to the all-model mean, in percent.
        diff_top1_avg_pc.append((mean_top1 - mean) / mean * 100)
        diff_top2_avg_pc.append((mean_top2 - mean) / mean * 100)
        diff_top3_avg_pc.append((mean_top3 - mean) / mean * 100)

    summary = (
        ('Top3 vs avg: ', diff_top3_avg),
        ('Top2 vs avg: ', diff_top2_avg),
        ('Top1 vs avg: ', diff_top1_avg),
        ('Top3 vs avg %: ', diff_top3_avg_pc),
        ('Top2 vs avg %: ', diff_top2_avg_pc),
        ('Top1 vs avg %: ', diff_top1_avg_pc),
    )
    for label, diffs in summary:
        print(label + str(np.mean(np.array(diffs))) + ' +- ' + str(stats.sem(diffs)))
    print()


def compute_clusters(iter_val=5, save=False, print_results=False):
    """Cluster the persistence diagrams of seeds 0-2 and report accuracy gains.

    Diagrams are computed once per seed. Then, for every iteration budget from
    2 to 19, each seed's diagrams are clustered with ``k=9`` and the resulting
    memberships are summarised by ``mean_performance``.

    Args:
        iter_val: currently unused.
        save: currently unused.
        print_results: when true, print for every diagram the indices of its
            three largest membership values.

    Returns:
        ``(r_accum, M)``: the per-seed memberships of the last iteration budget
        and the second result of the last ``cluster`` call.
    """
    D = [compute_diagrams(num_filt=100, maxdim=1, num_points=25, seed=seed)
         for seed in range(3)]

    for iterations in range(2, 20):
        r_accum = []
        for seed in range(3):
            r, M = cluster(D[seed], k=9, max_iter=iterations, verbose=False)

            if print_results:
                for row in r:
                    print(row.argsort()[-3:][::-1])

            r_accum.append(r)

        print("ITERATIONS: " + str(iterations))
        # mean_performance prints its own summary and returns None, so it is
        # called directly instead of being wrapped in print(), which only
        # added a stray "None" line.
        mean_performance(r_accum)
    return r_accum, M


def compute_true_boundaries():
    """Save, for every pair of the ten classes, each raw image with its flipped label.

    Each row of the stored array is the flattened image followed by
    ``abs(label - 1)``.
    """
    for idx1, idx2 in itertools.combinations(range(10), 2):
        data, labels = dp3.load_data(idx1, idx2, 'No')

        rows = [
            np.append(data[i].flatten().numpy(), abs(labels[i].numpy() - 1))
            for i in range(len(data))
        ]

        target = 'boundaries/kmnist/true_boundary_' + str(idx1) + str(idx2) + '.npy'
        np.save(target, np.array(rows))


if __name__ == '__main__':
    # Full pipeline, in dependency order: train the models, score them on all
    # tasks, dump model and true decision boundaries, then cluster the
    # resulting persistence diagrams.
    # NOTE(review): the stages use different hard-coded seeds (train_all: 1,
    # compute_taskwise_loss: 0, compute_decision_boundaries: 1 and 2,
    # compute_clusters: 0-2), so a single fresh run appears to rely on
    # checkpoints and result files left by earlier runs. Confirm before use.
    pipeline = (
        train_all,
        compute_taskwise_loss,
        compute_decision_boundaries,
        compute_true_boundaries,
        compute_clusters,
    )
    for stage in pipeline:
        stage()
